Lab 3
1. Obtain the dataset
The dataset (https://www.kaggle.com/datasets/biaiscience/dogs-vs-cats, released under an Open Data license) contains images of dogs and cats, divided into training (1000 dogs, 1000 cats), validation (500 dogs, 500 cats), and test (1000 dogs, 1000 cats) sets. Each image has a resolution of 180×180.
from keras.utils import image_dataset_from_directory
import pathlib
dataset_dir = pathlib.Path("./data/kaggle_dogs_vs_cats_small")
# load the dataset
train_dataset = image_dataset_from_directory(
    dataset_dir / "train",
    image_size=(180, 180),
    batch_size=32,
    label_mode="binary",
    shuffle=True,
)
validation_dataset = image_dataset_from_directory(
    dataset_dir / "validation",
    image_size=(180, 180),
    batch_size=32,
    label_mode="binary",
    shuffle=False,
)
test_dataset = image_dataset_from_directory(
    dataset_dir / "test",
    image_size=(180, 180),
    batch_size=32,
    label_mode="binary",
    shuffle=False,
)
Found 2000 files belonging to 2 classes.
Found 1000 files belonging to 2 classes.
Found 2000 files belonging to 2 classes.
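A quick sanity check on the loading step: with `batch_size=32`, the 2000 training images should yield ceil(2000 / 32) = 63 batches per epoch, which matches the `63/63` step count reported by `fit` later on. This is plain arithmetic, independent of Keras:

```python
import math

# 2000 training images split into batches of 32
steps_per_epoch = math.ceil(2000 / 32)
print(steps_per_epoch)  # 63
```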
2. EDA
2.1 Class distribution
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
class_names = train_dataset.class_names
# Calculate the class distribution for each subset
def get_class_counts(dataset):
    counts = {}
    for _, labels in dataset:
        unique, counts_unique = np.unique(labels, return_counts=True)
        for u, c in zip(unique, counts_unique):
            class_name = class_names[int(u)]
            counts[class_name] = counts.get(class_name, 0) + c
    return counts
train_counts = get_class_counts(train_dataset)
validation_counts = get_class_counts(validation_dataset)
test_counts = get_class_counts(test_dataset)
counts_df = pd.DataFrame(
    {"Train": train_counts, "Validation": validation_counts, "Test": test_counts},
    index=class_names,
)
print("Class distribution for each subset:")
counts_df
Class distribution for each subset:
|  | Train | Validation | Test |
|---|---|---|---|
| cat | 1000 | 500 | 1000 |
| dog | 1000 | 500 | 1000 |
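The counting logic inside `get_class_counts` can be checked on a toy label batch with plain NumPy (hypothetical labels, no dataset required):

```python
import numpy as np

# a hypothetical batch of binary labels (0 = cat, 1 = dog)
labels = np.array([0.0, 1.0, 1.0, 0.0, 1.0])
unique, counts = np.unique(labels, return_counts=True)
print({int(u): int(c) for u, c in zip(unique, counts)})  # {0: 2, 1: 3}
```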
# Plot the class distribution
# Pass figsize to pandas' plot; a separate plt.figure() would open its own
# (empty) figure, since DataFrame.plot creates a new one by default.
counts_df.plot(kind="bar", figsize=(6, 6))
plt.title("Class distribution")
plt.show()
2.2 Display sample images
from matplotlib import pyplot as plt
plt.figure(figsize=(12, 10))
for images, labels in train_dataset.take(1):
    for i in range(12):
        ax = plt.subplot(3, 4, i + 1)
        plt.imshow(images[i].numpy().astype("uint8"))
        plt.title(class_names[int(labels[i])])
        plt.axis("off")
2.3 Image size and color channels
# Get the first batch from the training dataset
for images, labels in train_dataset.take(1):
    print(f"Image size: {images.shape[1:3]}")
    print(f"Image color channels: {images.shape[3]}")
    break
Image size: (180, 180)
Image color channels: 3
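The same shape arithmetic can be verified without loading any images: a zero-filled NumPy array standing in for one batch has the same (batch, height, width, channels) layout:

```python
import numpy as np

# stand-in for one batch of 32 RGB images at 180x180
batch = np.zeros((32, 180, 180, 3), dtype=np.uint8)
print(batch.shape[1:3])  # (180, 180)
print(batch.shape[3])    # 3
```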
2.4 Data augmentation
from keras import layers
import keras
# Data augmentation
data_augmentation = keras.Sequential(
    [
        layers.RandomFlip("horizontal"),
        layers.RandomRotation(0.1),
        layers.RandomZoom(0.2),
    ]
)
# Display augmented images
plt.figure(figsize=(10, 10))
for images, _ in train_dataset.take(1):
    for i in range(9):
        augmented_images = data_augmentation(images)
        ax = plt.subplot(3, 3, i + 1)
        plt.imshow(augmented_images[0].numpy().astype("uint8"))
        plt.axis("off")
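As a sketch of what `RandomFlip("horizontal")` does to a single image, a horizontal flip simply reverses the width axis; in NumPy this is a left-right flip (a toy 2×3 "image" for illustration):

```python
import numpy as np

img = np.array([[1, 2, 3],
                [4, 5, 6]])
flipped = np.fliplr(img)  # mirror along the width axis
print(flipped.tolist())   # [[3, 2, 1], [6, 5, 4]]
```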
3. Training vanilla CNN model
Define a vanilla CNN model with data augmentation and dropout.
def create_vanilla_cnn_model():
    inputs = keras.Input(shape=(180, 180, 3))
    x = data_augmentation(inputs)
    x = layers.Rescaling(1.0 / 255)(x)
    x = layers.Conv2D(filters=32, kernel_size=3, activation="relu")(x)
    x = layers.MaxPooling2D(pool_size=2)(x)
    x = layers.Conv2D(filters=64, kernel_size=3, activation="relu")(x)
    x = layers.MaxPooling2D(pool_size=2)(x)
    x = layers.Conv2D(filters=128, kernel_size=3, activation="relu")(x)
    x = layers.MaxPooling2D(pool_size=2)(x)
    x = layers.Conv2D(filters=256, kernel_size=3, activation="relu")(x)
    x = layers.MaxPooling2D(pool_size=2)(x)
    x = layers.Conv2D(filters=256, kernel_size=3, activation="relu")(x)
    x = layers.Flatten()(x)
    x = layers.Dense(256, activation="relu")(x)
    x = layers.Dropout(0.5)(x)
    outputs = layers.Dense(1, activation="sigmoid")(x)
    model = keras.Model(inputs=inputs, outputs=outputs)
    model.compile(loss="binary_crossentropy", optimizer="rmsprop", metrics=["accuracy"])
    return model
vanilla_cnn_model = create_vanilla_cnn_model()
vanilla_cnn_model.summary()
Model: "model"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input_1 (InputLayer) [(None, 180, 180, 3)] 0
sequential (Sequential) (None, 180, 180, 3) 0
rescaling (Rescaling) (None, 180, 180, 3) 0
conv2d (Conv2D) (None, 178, 178, 32) 896
max_pooling2d (MaxPooling2D) (None, 89, 89, 32) 0
conv2d_1 (Conv2D) (None, 87, 87, 64) 18496
max_pooling2d_1 (MaxPooling2D) (None, 43, 43, 64) 0
conv2d_2 (Conv2D) (None, 41, 41, 128) 73856
max_pooling2d_2 (MaxPooling2D) (None, 20, 20, 128) 0
conv2d_3 (Conv2D) (None, 18, 18, 256) 295168
max_pooling2d_3 (MaxPooling2D) (None, 9, 9, 256) 0
conv2d_4 (Conv2D) (None, 7, 7, 256) 590080
flatten (Flatten) (None, 12544) 0
dense (Dense) (None, 256) 3211520
dropout (Dropout) (None, 256) 0
dense_1 (Dense) (None, 1) 257
=================================================================
Total params: 4,190,273
Trainable params: 4,190,273
Non-trainable params: 0
_________________________________________________________________
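The shapes in the summary follow from simple arithmetic: each `valid` 3×3 convolution shrinks a spatial side by 2, and each 2×2 max pooling halves it (floor division). Tracing this from 180 down reproduces the final 7×7×256 feature map and the 3,211,520 parameters of the first dense layer. A small consistency check of the summary (not Keras code):

```python
def conv_out(n, k=3):
    return n - k + 1   # 'valid' convolution, stride 1

def pool_out(n, p=2):
    return n // p      # 2x2 max pooling, stride 2

n = 180
for _ in range(4):               # four conv + pool stages
    n = pool_out(conv_out(n))
n = conv_out(n)                  # final conv has no pooling after it
print(n)                         # 7

flat = n * n * 256               # flattened 7x7x256 feature map
print(flat)                      # 12544
print(flat * 256 + 256)          # first dense layer params: 3211520
```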
Train the vanilla CNN model.
vanilla_callbacks = [
    keras.callbacks.ModelCheckpoint(
        filepath="./models/vanilla_cnn_model.h5",
        save_best_only=True,
        monitor="val_loss",
    )
]
vanilla_history = vanilla_cnn_model.fit(
    train_dataset,
    epochs=100,
    validation_data=validation_dataset,
    callbacks=vanilla_callbacks,
)
Epoch 1/100
63/63 [==============================] - 18s 281ms/step - loss: 0.7161 - accuracy: 0.4925 - val_loss: 0.6921 - val_accuracy: 0.5020
Epoch 2/100
63/63 [==============================] - 18s 285ms/step - loss: 0.6967 - accuracy: 0.5015 - val_loss: 0.6923 - val_accuracy: 0.6140
...
Epoch 99/100
63/63 [==============================] - 20s 314ms/step - loss: 0.1607 - accuracy: 0.9410 - val_loss: 0.7712 - val_accuracy: 0.8300
Epoch 100/100
63/63 [==============================] - 20s 310ms/step - loss: 0.1223 - accuracy: 0.9590 - val_loss: 0.7575 - val_accuracy: 0.8680
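The reported loss is binary cross-entropy, −[y·log(p) + (1 − y)·log(1 − p)], averaged over the batch. For a single example it is easy to evaluate by hand (illustrative values only, not taken from the run above):

```python
import math

def bce(y, p):
    # binary cross-entropy for one example with label y and predicted probability p
    return -(y * math.log(p) + (1 - y) * math.log(1 - p))

print(round(bce(1, 0.9), 4))  # confident correct prediction -> small loss: 0.1054
print(round(bce(1, 0.1), 4))  # confident wrong prediction -> large loss: 2.3026
```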
def plot_history(history):
    accuracy = history.history["accuracy"]
    val_accuracy = history.history["val_accuracy"]
    loss = history.history["loss"]
    val_loss = history.history["val_loss"]
    epochs_range = range(1, len(accuracy) + 1)
    # accuracy
    plt.figure(figsize=(16, 6))
    plt.subplot(1, 2, 1)
    plt.plot(epochs_range, accuracy, "bo", label="Training accuracy")
    plt.plot(epochs_range, val_accuracy, "b", label="Validation accuracy")
    plt.xlabel("Epochs")
    plt.ylabel("Accuracy")
    plt.title("Training and validation accuracy")
    plt.legend(loc="upper left")
    # loss
    plt.subplot(1, 2, 2)
    plt.plot(epochs_range, loss, "bo", label="Training loss")
    plt.plot(epochs_range, val_loss, "b", label="Validation loss")
    plt.xlabel("Epochs")
    plt.ylabel("Loss")
    plt.title("Training and validation loss")
    plt.legend(loc="upper left")
    plt.show()
plot_history(vanilla_history)
4. Fine-tuning VGG16 model
Instantiate the VGG16 model with pre-trained ImageNet weights, removing the top (classifier) layers.
from tensorflow import keras
conv_base = keras.applications.vgg16.VGG16(weights="imagenet", include_top=False)
conv_base.summary()
Model: "vgg16"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input_2 (InputLayer) [(None, None, None, 3)] 0
block1_conv1 (Conv2D) (None, None, None, 64) 1792
block1_conv2 (Conv2D) (None, None, None, 64) 36928
block1_pool (MaxPooling2D) (None, None, None, 64) 0
block2_conv1 (Conv2D) (None, None, None, 128) 73856
block2_conv2 (Conv2D) (None, None, None, 128) 147584
block2_pool (MaxPooling2D) (None, None, None, 128) 0
block3_conv1 (Conv2D) (None, None, None, 256) 295168
block3_conv2 (Conv2D) (None, None, None, 256) 590080
block3_conv3 (Conv2D) (None, None, None, 256) 590080
block3_pool (MaxPooling2D) (None, None, None, 256) 0
block4_conv1 (Conv2D) (None, None, None, 512) 1180160
block4_conv2 (Conv2D) (None, None, None, 512) 2359808
block4_conv3 (Conv2D) (None, None, None, 512) 2359808
block4_pool (MaxPooling2D) (None, None, None, 512) 0
block5_conv1 (Conv2D) (None, None, None, 512) 2359808
block5_conv2 (Conv2D) (None, None, None, 512) 2359808
block5_conv3 (Conv2D) (None, None, None, 512) 2359808
block5_pool (MaxPooling2D) (None, None, None, 512) 0
=================================================================
Total params: 14,714,688
Trainable params: 14,714,688
Non-trainable params: 0
_________________________________________________________________
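The parameter counts in the VGG16 summary follow the usual Conv2D formula, k·k·c_in·c_out + c_out (weights plus one bias per output channel). For example, `block1_conv1` maps 3 input channels to 64 with 3×3 kernels:

```python
def conv_params(k, c_in, c_out):
    # k*k*c_in weights per output channel, plus one bias each
    return k * k * c_in * c_out + c_out

print(conv_params(3, 3, 64))     # block1_conv1: 1792
print(conv_params(3, 512, 512))  # block5_conv1: 2359808
```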
Define a new model that uses VGG16 as the base, freezes all layers except the top four, and adds a custom classification head for fine-tuning. Data augmentation is applied as well.
def create_finetuned_model(conv_base):
    # Freeze all layers except the last four.
    conv_base.trainable = True
    for layer in conv_base.layers[:-4]:
        layer.trainable = False
    inputs = keras.Input(shape=(180, 180, 3))
    x = data_augmentation(inputs)
    x = keras.applications.vgg16.preprocess_input(x)
    x = conv_base(x)
    x = layers.Flatten()(x)
    x = layers.Dense(256, activation="relu")(x)
    x = layers.Dropout(0.5)(x)
    outputs = layers.Dense(1, activation="sigmoid")(x)
    model = keras.Model(inputs=inputs, outputs=outputs)
    model.compile(
        loss="binary_crossentropy",
        optimizer=keras.optimizers.RMSprop(learning_rate=1e-5),
        metrics=["accuracy"],
    )
    return model
finetuned_model = create_finetuned_model(conv_base)
finetuned_model.summary()
Model: "model_1"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input_3 (InputLayer) [(None, 180, 180, 3)] 0
sequential (Sequential) (None, 180, 180, 3) 0
tf.__operators__.getitem (SlicingOpLambda) (None, 180, 180, 3) 0
tf.nn.bias_add (TFOpLambda) (None, 180, 180, 3) 0
vgg16 (Functional) (None, None, None, 512) 14714688
flatten_1 (Flatten) (None, 12800) 0
dense_2 (Dense) (None, 256) 3277056
dropout_1 (Dropout) (None, 256) 0
dense_3 (Dense) (None, 1) 257
=================================================================
Total params: 17,992,001
Trainable params: 10,356,737
Non-trainable params: 7,635,264
_________________________________________________________________
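The trainable/non-trainable split in the summary can be reproduced by hand: the unfrozen part is the three `block5` convolutions (the fourth unfrozen layer, `block5_pool`, has no parameters) plus the new head, and the frozen part is the rest of VGG16. A consistency check on the numbers above:

```python
block5_conv = 3 * 2359808                    # block5_conv1..3, from the VGG16 summary
head = (12800 * 256 + 256) + (256 * 1 + 1)   # dense_2 + dense_3
trainable = block5_conv + head
non_trainable = 14714688 - block5_conv       # frozen VGG16 layers
print(trainable)       # 10356737
print(non_trainable)   # 7635264
```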
Train the fine-tuned VGG16 model.
finetuned_callbacks = [
    keras.callbacks.ModelCheckpoint(
        filepath="./models/finetuned_vgg16_model.h5",
        save_best_only=True,
        monitor="val_loss",
    )
]
finetuned_history = finetuned_model.fit(
    train_dataset,
    epochs=50,
    validation_data=validation_dataset,
    callbacks=finetuned_callbacks,
)
Epoch 1/50
63/63 [==============================] - 61s 958ms/step - loss: 3.1977 - accuracy: 0.6955 - val_loss: 0.5484 - val_accuracy: 0.8990
Epoch 2/50
63/63 [==============================] - 60s 963ms/step - loss: 1.1677 - accuracy: 0.8155 - val_loss: 0.2910 - val_accuracy: 0.9350
...
Epoch 49/50
63/63 [==============================] - 60s 961ms/step - loss: 0.0075 - accuracy: 0.9970 - val_loss: 0.1258 - val_accuracy: 0.9830
Epoch 50/50
63/63 [==============================] - 60s 953ms/step - loss: 0.0173 - accuracy: 0.9945 - val_loss: 0.1304 - val_accuracy: 0.9820
plot_history(finetuned_history)
5. Evaluate the models
Define functions to evaluate the models and display the results.
import keras
from sklearn.metrics import classification_report, precision_recall_curve, auc
import tensorflow as tf
def get_predictions(model, dataset):
    """Collect true labels, predicted probabilities, and thresholded labels."""
    y_true_all = []
    y_pred_all = []
    y_pred_bin_all = []
    for images, y_true in dataset:
        y_pred = model.predict(images, verbose=0)  # silence per-batch progress bars
        y_pred_bin = y_pred > 0.5  # sigmoid output -> binary label
        y_true_all.extend(y_true)
        y_pred_all.extend(y_pred)
        y_pred_bin_all.extend(y_pred_bin)
    return (np.array(y_true_all).flatten().astype(int),
            np.array(y_pred_all).flatten(),
            np.array(y_pred_bin_all).flatten().astype(int))
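The per-batch bookkeeping above can be illustrated in isolation with toy values (the probabilities below are made up, not model outputs):

```python
import numpy as np

# Two hypothetical batches of sigmoid outputs, shaped (batch, 1) as Keras returns them
batch1 = np.array([[0.92], [0.13]])
batch2 = np.array([[0.51], [0.49]])

y_pred_all = []
y_pred_bin_all = []
for y_pred in (batch1, batch2):
    y_pred_all.extend(y_pred)
    y_pred_bin_all.extend(y_pred > 0.5)  # same threshold rule as get_predictions

probs = np.array(y_pred_all).flatten()                    # [0.92, 0.13, 0.51, 0.49]
labels = np.array(y_pred_bin_all).flatten().astype(int)   # [1, 0, 1, 0]
```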
def calc_confusion_matrix(y_true, y_pred_bin):
    # Rows are true labels, columns are predictions; class 1 ("dog") is positive
    cm = tf.math.confusion_matrix(y_true, y_pred_bin).numpy()
    TP, TN, FP, FN = cm[1, 1], cm[0, 0], cm[0, 1], cm[1, 0]
    return TP, TN, FP, FN
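The four cells returned here are enough to reproduce the headline metrics by hand. As a sanity check, plugging in the vanilla CNN's test-set cells from section 5.1 (with "dog" as the positive class) recovers the values in that classification report:

```python
# Confusion-matrix cells for the vanilla CNN on the test set (section 5.1)
TP, TN, FP, FN = 831, 815, 185, 169

accuracy = (TP + TN) / (TP + TN + FP + FN)          # 0.823
precision = TP / (TP + FP)                          # ~0.818
recall = TP / (TP + FN)                             # 0.831
f1 = 2 * precision * recall / (precision + recall)  # ~0.824
```

These match the "dog" row of the report to two decimals.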
def plot_precision_recall_curve(y_true, y_pred):
    precision, recall, _ = precision_recall_curve(y_true, y_pred)
    pr_auc = auc(recall, precision)
    plt.figure(figsize=(8, 6))
    plt.plot(recall, precision, marker=".", label=f"AUC: {pr_auc:.3f}")
    plt.xlabel("Recall")
    plt.ylabel("Precision")
    plt.title("Precision-Recall Curve")
    plt.legend(loc="lower left")
    plt.grid()
    plt.show()
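For reference, `precision_recall_curve` and `auc` can be exercised on a tiny hand-made example (the toy scores below come from scikit-learn's documentation, not from either model):

```python
import numpy as np
from sklearn.metrics import auc, precision_recall_curve

# Toy binary labels and classifier scores: class-1 examples mostly score higher
y_true = np.array([0, 0, 1, 1])
y_score = np.array([0.1, 0.4, 0.35, 0.8])

# One (precision, recall) point per candidate threshold
precision, recall, thresholds = precision_recall_curve(y_true, y_score)

# Area under the PR curve via the trapezoidal rule
pr_auc = auc(recall, precision)  # ~0.792
```

The mid-ranked positive example (score 0.35) is what keeps the AUC below 1.0 here.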
def evaluate_model(model_path, dataset):
    model = keras.models.load_model(model_path)
    y_true, y_pred, y_pred_bin = get_predictions(model, dataset)
    TP, TN, FP, FN = calc_confusion_matrix(y_true, y_pred_bin)
    result = f"\n{classification_report(y_true, y_pred_bin, target_names=class_names)}\n"
    result += "\nConfusion matrix:\n"
    result += f"{'-' * 25}\n"
    result += f"| TP: {TP:5} | FP: {FP:5} |\n"
    result += f"{'-' * 25}\n"
    result += f"| FN: {FN:5} | TN: {TN:5} |\n"
    result += f"{'-' * 25}"
    return result, y_true, y_pred, y_pred_bin
def show_misclassified_images(y_true, y_pred_bin, dataset, max_num=5):
    errors = np.where(y_true != y_pred_bin)[0]
    if len(errors) == 0:
        print("No misclassified images.")
        return
    max_cols = 5
    num_images = min(len(errors), max_num)
    rows = (num_images + max_cols - 1) // max_cols
    selected_errors = np.random.choice(errors, num_images, replace=False)
    # Materialize the unbatched dataset once instead of once per image
    samples = list(dataset.unbatch().as_numpy_iterator())
    fig, axes = plt.subplots(rows, max_cols, figsize=(15, 4 * rows))
    axes = np.atleast_1d(axes).flatten()
    for i, idx in enumerate(selected_errors):
        image, label = samples[idx]
        prediction = class_names[y_pred_bin[idx]]
        axes[i].imshow(image.astype("uint8"))
        axes[i].set_title(f"True: {class_names[int(label)]}  Pred: {prediction}")
        axes[i].axis("off")
    # Hide the remaining axes
    for i in range(num_images, len(axes)):
        axes[i].axis("off")
    plt.suptitle("Misclassified images")
    plt.tight_layout()
    plt.show()
5.1 Vanilla CNN modelΒΆ
# Evaluate the vanilla CNN model
(vanilla_report,
y_true_vanilla,
y_pred_vanilla,
y_pred_bin_vanilla) = evaluate_model("./models/vanilla_cnn_model.h5", test_dataset)
print(f"Vanilla CNN model:\n{vanilla_report}")
Vanilla CNN model:
              precision    recall  f1-score   support

         cat       0.83      0.81      0.82      1000
         dog       0.82      0.83      0.82      1000

    accuracy                           0.82      2000
   macro avg       0.82      0.82      0.82      2000
weighted avg       0.82      0.82      0.82      2000
Confusion matrix:
-------------------------
| TP: 831 | FP: 185 |
-------------------------
| FN: 169 | TN: 815 |
-------------------------
# plot the precision-recall curve
plot_precision_recall_curve(y_true_vanilla, y_pred_vanilla)
# Show misclassified images
show_misclassified_images(y_true_vanilla, y_pred_bin_vanilla, test_dataset, 10)
5.2 Fine-tuned VGG16 modelΒΆ
# Evaluate the fine-tuned VGG16 model
(finetuned_report,
y_true_finetuned,
y_pred_finetuned,
y_pred_bin_finetuned) = evaluate_model("./models/finetuned_vgg16_model.h5", test_dataset)
print(f"Fine-tuned VGG16 model:\n{finetuned_report}")
Fine-tuned VGG16 model:
              precision    recall  f1-score   support

         cat       0.97      0.98      0.98      1000
         dog       0.98      0.97      0.98      1000

    accuracy                           0.98      2000
   macro avg       0.98      0.98      0.98      2000
weighted avg       0.98      0.98      0.98      2000
Confusion matrix:
-------------------------
| TP: 967 | FP: 16 |
-------------------------
| FN: 33 | TN: 984 |
-------------------------
# plot the precision-recall curve
plot_precision_recall_curve(y_true_finetuned, y_pred_finetuned)
# Show misclassified images
show_misclassified_images(y_true_finetuned, y_pred_bin_finetuned, test_dataset, 10)
6. ConclusionΒΆ
In this lab, I trained two different models for the cat and dog classification task:
- Vanilla Convolutional Neural Network (CNN):
- Advantages: Simple structure and shorter training time.
- Disadvantages: Lower accuracy and weaker evaluation metrics than the fine-tuned VGG16 model, and more prone to overfitting, which set in around epoch 35. Its precision-recall curve has a smaller area under the curve (AUC = 0.897), indicating that the model struggles to maintain high precision and recall simultaneously.
- Fine-tuned pre-trained VGG16 model:
- Advantages: Leveraged the feature extraction capabilities of the pre-trained model, achieving better accuracy, precision, recall, and F1 score. After fine-tuning, the model generalized better on the test set: its precision-recall curve has a near-perfect AUC of 0.998, reflecting a consistently excellent ability to distinguish the two classes.
- Disadvantages: Longer training time and higher computational resource requirements; the training time per epoch was roughly three times that of the vanilla CNN.
To further improve the model's performance, the following measures can be taken:
- Optimize hyperparameters such as learning rate, batch size, and optimizer.
- Experiment with more complex pre-trained models (e.g., ResNet) for fine-tuning or use ensemble learning by combining multiple models.
- Increase the size of the dataset to enhance generalization and reduce overfitting risks.
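As a sketch of the first point, part of the hyperparameter handling can be delegated to Keras callbacks: `ReduceLROnPlateau` lowers the learning rate when validation loss stalls, and `EarlyStopping` would halt training before the overfitting observed around epoch 35. The settings below are illustrative, not tuned for this dataset:

```python
import keras

# Illustrative callback configuration (patience/factor values are assumptions)
callbacks = [
    keras.callbacks.ReduceLROnPlateau(
        monitor="val_loss", factor=0.5, patience=3, min_lr=1e-6
    ),
    keras.callbacks.EarlyStopping(
        monitor="val_loss", patience=8, restore_best_weights=True
    ),
    keras.callbacks.ModelCheckpoint(
        "./models/best_model.h5", monitor="val_loss", save_best_only=True
    ),
]

# Passed to fit() alongside the existing arguments, e.g.:
# model.fit(train_dataset, validation_data=validation_dataset,
#           epochs=50, callbacks=callbacks)
```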
Overall, the fine-tuned VGG16 model performed exceptionally well in this task, demonstrating the power of transfer learning for image classification, particularly when a pre-trained model is fine-tuned with domain-specific data.